/*
 * Copyright 2007 Yannick Versley / Univ. Tuebingen
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package nlpeap.ml.maxent;

import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import nlpeap.ml.util.Alphabet;
import nlpeap.ml.util.DenseVector;
import nlpeap.ml.util.Minimizable;
import nlpeap.ml.util.Minimization;
import nlpeap.ml.util.SparseVector;
import riso.numerical.LBFGS.ExceptionWithIflag;

/**
 * Estimates the weights of a log-linear ranking model.
 *
 * <p>Training data consists of groups; group {@code i} has a list of "good"
 * (positive) instance vectors and a list of "bad" (negative) instance vectors.
 * With {@code mu_good = sum over good x of exp(w.x)} and
 * {@code mu_bad = sum over bad x of exp(w.x)}, the objective minimized by
 * L-BFGS is
 * {@code C*||w||^2 + sum_i [ -log(mu_good_i) + log(mu_good_i + mu_bad_i) ]},
 * i.e. the L2-regularized negative log of the probability mass assigned to
 * the good candidates of each group.
 *
 * @author yannick
 */
public class ParameterEstimator implements Minimizable
{
    /** Linear scores are clipped to [-SCORE_CLIP, SCORE_CLIP] before exp(). */
    private static final double SCORE_CLIP=35.;
    /** Clipped scores beyond this magnitude are reported on stderr. */
    private static final double SCORE_WARN=30.;
    /** Number of largest-magnitude weights printed after estimation. */
    private static final int N_TOP_WEIGHTS=200;

    /** Weight of the L2 regularization term C*||w||^2 in the objective. */
    double C=1.0;
    /** Per-group positive instances; group i pairs with _neg_insts.get(i). */
    protected List<List<SparseVector>> _pos_insts;
    /** Per-group negative instances; group i pairs with _pos_insts.get(i). */
    protected List<List<SparseVector>> _neg_insts;

    /**
     * Creates an estimator over the given training groups.
     *
     * @param n_parameters number of model parameters (currently unused here;
     *        the parameter vector length is taken from the arrays passed to
     *        {@link #evaluateFunction})
     * @param pos_insts positive instances, one list per group
     * @param neg_insts negative instances, one list per group; must have at
     *        least as many groups as {@code pos_insts}
     */
    public ParameterEstimator(int n_parameters,
            List<List<SparseVector>> pos_insts,
            List<List<SparseVector>> neg_insts) {
        _pos_insts=pos_insts;
        _neg_insts=neg_insts;
    }

    /**
     * Computes the regularized objective and its gradient.
     *
     * @param parameters current weight vector w
     * @param grad output array, completely overwritten with the gradient of
     *        the objective with respect to w; must be at least as long as
     *        {@code parameters}
     * @return the objective value
     */
    public double evaluateFunction(double[] parameters, double[] grad) {
        // Regularizer C*||w||^2 and its gradient 2*C*w.
        double val=C*DenseVector.dotSelf(parameters);
        final double val0=val;
        for (int i=0; i<parameters.length; i++) {
            grad[i]=2.0*C*parameters[i];
        }
        double[] deriv_good=new double[grad.length];
        double[] deriv_bad=new double[grad.length];
        // extremes[0]: most negative clipped score, extremes[1]: most positive.
        double[] extremes=new double[] {0.0, 0.0};
        for (int i=0; i<_pos_insts.size(); i++) {
            Arrays.fill(deriv_good,0.0);
            Arrays.fill(deriv_bad,0.0);
            double mu_good=sumExp(_pos_insts.get(i),parameters,deriv_good,extremes);
            double mu_bad=sumExp(_neg_insts.get(i),parameters,deriv_bad,extremes);
            double mu_all=mu_good+mu_bad;
            // loss_i = -log(mu_good) + log(mu_good + mu_bad)
            val -= Math.log(mu_good)-Math.log(mu_all);
            // d loss_i / dw = -deriv_good/mu_good + (deriv_good + deriv_bad)/mu_all
            DenseVector.plusEquals(grad,deriv_good,-1.0/mu_good);
            DenseVector.plusEquals(grad,deriv_good,1.0/mu_all);
            DenseVector.plusEquals(grad,deriv_bad,1.0/mu_all);
        }
        if (extremes[1]>SCORE_WARN) {
            System.err.format("m<=%f\n",extremes[1]);
        }
        // Must test the minimum here: the maximum is never negative.
        if (extremes[0]<-SCORE_WARN) {
            System.err.format("m>=%f\n",extremes[0]);
        }
        System.err.format("loss=%f, perplexity=%f\n",val,
                Math.exp((val-val0)/_pos_insts.size()));
        return val;
    }

    /**
     * Sums exp(w.x) over the given instances and adds exp(w.x)*x into
     * {@code deriv}. Scores are clipped to [-SCORE_CLIP, SCORE_CLIP] before
     * exponentiation to avoid overflow; the most negative / most positive
     * clipped score is recorded in {@code extremes[0]} / {@code extremes[1]}.
     *
     * @return the sum of the (clipped) exponentiated scores
     */
    private static double sumExp(List<SparseVector> insts, double[] parameters,
            double[] deriv, double[] extremes) {
        double mu=0.0;
        for (SparseVector ex: insts) {
            double prod=ex.dotProduct(parameters);
            if (prod>SCORE_CLIP) {
                extremes[1]=Math.max(extremes[1],prod);
                prod=SCORE_CLIP;
            } else if (prod<-SCORE_CLIP) {
                extremes[0]=Math.min(extremes[0],prod);
                prod=-SCORE_CLIP;
            }
            double m=Math.exp(prod);
            mu+=m;
            ex.addTo(deriv,m);
        }
        return mu;
    }

    /**
     * Trains a model from {@code prefix.obj} (serialized instance groups) and
     * {@code prefix.dict} (serialized {@link Alphabet}), prints the
     * largest-magnitude weights to stderr and writes the weight vector to
     * {@code prefix.param}.
     *
     * @param prefix path prefix of the input and output files
     * @throws RuntimeException if the L-BFGS minimization fails
     */
    public static void do_estimation(String prefix)
        throws FileNotFoundException, IOException, ClassNotFoundException
    {
        List<List<SparseVector>> pos_insts=new ArrayList<List<SparseVector>>();
        List<List<SparseVector>> neg_insts=new ArrayList<List<SparseVector>>();
        readInstances(prefix+".obj",pos_insts,neg_insts);
        Alphabet dict=readAlphabet(prefix+".dict");
        double[] parameters=new double[dict.size()];
        ParameterEstimator estimator=
                new ParameterEstimator(dict.size(),pos_insts,neg_insts);
        try {
            // Debugging aid (gradient check):
            // Minimization.testGradient(parameters, estimator);
            Minimization.runLBFGS(parameters,estimator);
        } catch (ExceptionWithIflag ex) {
            // The cause is preserved; main() prints the full stack trace.
            throw new RuntimeException("Minimization failed",ex);
        }
        printTopWeights(prefix,dict,parameters,N_TOP_WEIGHTS);
        writeParameters(prefix+".param",parameters);
    }

    /**
     * Reads alternating (positive, negative) instance lists until end of
     * stream, appending them to {@code pos_insts} and {@code neg_insts}.
     *
     * @throws EOFException if the stream ends after a positive list without
     *         its matching negative list
     */
    @SuppressWarnings("unchecked") // the .obj file stores List<SparseVector> objects
    private static void readInstances(String filename,
            List<List<SparseVector>> pos_insts,
            List<List<SparseVector>> neg_insts)
        throws IOException, ClassNotFoundException
    {
        FileInputStream fis=new FileInputStream(filename);
        try {
            ObjectInputStream is=new ObjectInputStream(fis);
            while (true) {
                List<SparseVector> pos;
                try {
                    pos=(List<SparseVector>)is.readObject();
                } catch (EOFException e) {
                    break; // clean end of stream
                }
                List<SparseVector> neg;
                try {
                    neg=(List<SparseVector>)is.readObject();
                } catch (EOFException e) {
                    // Would otherwise leave the two lists with different lengths.
                    EOFException truncated=new EOFException(filename
                            +": positive instances without matching negative instances");
                    truncated.initCause(e);
                    throw truncated;
                }
                pos_insts.add(pos);
                neg_insts.add(neg);
            }
        } finally {
            fis.close();
        }
    }

    /** Reads a single serialized {@link Alphabet} from {@code filename}. */
    private static Alphabet readAlphabet(String filename)
        throws IOException, ClassNotFoundException
    {
        FileInputStream fis=new FileInputStream(filename);
        try {
            return (Alphabet)new ObjectInputStream(fis).readObject();
        } finally {
            fis.close();
        }
    }

    /**
     * Prints the {@code n} parameters with the largest absolute weight to
     * stderr, in ascending order of magnitude (fewer if there are fewer
     * parameters).
     */
    private static void printTopWeights(String prefix, Alphabet dict,
            double[] parameters, int n) {
        final Map<String,Double> paramValues=new HashMap<String,Double>();
        for (int i=0; i<parameters.length; i++) {
            paramValues.put((String)dict.lookupObject(i),parameters[i]);
        }
        System.err.format("highest-weight parameters for %s\n",prefix);
        List<String> keys=new ArrayList<String>(paramValues.keySet());
        Collections.sort(keys, new Comparator<String>() {
            public int compare(String key1, String key2) {
                // Double.compare keeps the ordering total even for NaN weights.
                return Double.compare(Math.abs(paramValues.get(key1)),
                        Math.abs(paramValues.get(key2)));
            }
        });
        int from=Math.max(0,keys.size()-n);
        for (String key: keys.subList(from,keys.size())) {
            System.err.format("%s: %f\n",key,paramValues.get(key));
        }
    }

    /** Serializes {@code parameters} to {@code filename}, flushing and closing it. */
    private static void writeParameters(String filename, double[] parameters)
        throws IOException
    {
        FileOutputStream fos=new FileOutputStream(filename);
        try {
            ObjectOutputStream oos=new ObjectOutputStream(fos);
            oos.writeObject(parameters);
            // ObjectOutputStream buffers internally; without a flush the
            // file may be left incomplete.
            oos.flush();
        } finally {
            fos.close();
        }
    }

    /**
     * Runs {@link #do_estimation} for each prefix given on the command line,
     * or for the three default coreference ranker models if none is given.
     * Exits with status 1 on any exception.
     */
    public static void main(String[] args) {
        System.out.println("ParameterEstimator");
        try {
            if (args.length>0) {
                for (String s: args) {
                    do_estimation(s);
                }
            } else {
                do_estimation("models/coref/ranker_pro");
                do_estimation("models/coref/ranker_def");
                do_estimation("models/coref/ranker_app");
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }
}